{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "HRFAE.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "LLQNYveOD3jr",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Fetch the HRFAE sources and make the repository root the working directory;\n",
        "# the later cells use paths relative to it (./configs, ./logs, ./test, ./models).\n",
        "!git clone https://github.com/InterDigitalInc/HRFAE.git\n",
        "%cd HRFAE"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NcsZC87Nab_x",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Run the download script that lives in logs/001, then go back up two levels\n",
        "# to the directory the notebook was in before this cell.\n",
        "# NOTE(review): presumably the script fetches the pretrained checkpoint that is\n",
        "# later loaded from ./logs/001/checkpoint - confirm against download.sh.\n",
        "%cd logs/001\n",
        "!./download.sh\n",
        "%cd ./../.."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3FeQyYohljVu",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import argparse\n",
        "import os\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.utils.data as data\n",
        "import yaml\n",
        "\n",
        "import matplotlib\n",
        "matplotlib.use('agg')\n",
        "from matplotlib import pyplot as plt\n",
        "%matplotlib inline\n",
        "\n",
        "import warnings\n",
        "warnings.filterwarnings('ignore')\n",
        "\n",
        "from PIL import Image\n",
        "from torchvision import transforms, utils\n",
        "\n",
        "from datasets import *\n",
        "from nets import *\n",
        "from functions import *\n",
        "from trainer import *"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VVpLnIp3l3EX",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Options mirror a command-line interface. parse_known_args() is used below, so\n",
        "# arguments the parser does not recognise are ignored instead of aborting; when\n",
        "# none of these flags is present in sys.argv every option takes its default.\n",
        "parser = argparse.ArgumentParser()\n",
        "parser.add_argument('--config', type=str, default='001', help='path to the config file.')\n",
        "parser.add_argument('--vgg_model_path', type=str, default='./models/dex_imdb_wiki.caffemodel.pt', help='pretrained age classifier')\n",
        "parser.add_argument('--log_path', type=str, default='./logs/', help='log file path')\n",
        "# NOTE(review): type=bool turns any non-empty string (even 'False') into True.\n",
        "parser.add_argument('--multigpu', type=bool, default=False, help='use multiple gpus')\n",
        "parser.add_argument('--checkpoint', type=str, default='', help='checkpoint file path')\n",
        "parser.add_argument('--img_path', type=str, default='./test/input/', help='test image path')\n",
        "parser.add_argument('--out_path', type=str, default='./test/output/', help='test output path')\n",
        "parser.add_argument('--target_age', type=int, default=55, help='Age transform target, interger value between 20 and 70')\n",
        "opts = parser.parse_known_args()[0]\n",
        "\n",
        "# Directory of the selected run, e.g. ./logs/001/ with the defaults above.\n",
        "log_dir = os.path.join(opts.log_path, opts.config) + '/'\n",
        "# exist_ok makes the cell safe to re-run without a separate existence check.\n",
        "os.makedirs(opts.out_path, exist_ok=True)\n",
        "\n",
        "# Read the YAML config through a context manager so the file is closed again.\n",
        "# safe_load is used because yaml.load() without an explicit Loader raises a\n",
        "# TypeError on PyYAML >= 6.\n",
        "with open('./configs/' + opts.config + '.yaml', 'r') as config_file:\n",
        "    config = yaml.safe_load(config_file)\n",
        "# NOTE(review): this tuple is (width, height) while torchvision's Resize reads a\n",
        "# 2-tuple as (height, width); harmless only for square configs - confirm.\n",
        "img_size = (config['input_w'], config['input_h'])\n",
        "\n",
        "# Initialize trainer and move it to the GPU (the notebook requests a GPU runtime).\n",
        "trainer = Trainer(config)\n",
        "device = torch.device('cuda')\n",
        "trainer.to(device)\n",
        "\n",
        "# Load pretrained model: an explicit --checkpoint wins, otherwise fall back to\n",
        "# the file named 'checkpoint' inside log_dir.\n",
        "if opts.checkpoint:\n",
        "    trainer.load_checkpoint(opts.checkpoint)\n",
        "else:\n",
        "    trainer.load_checkpoint(log_dir + 'checkpoint')\n",
        "\n",
        "def preprocess(img_name):\n",
        "    '''Load one image from opts.img_path and return it as a normalized tensor.\n",
        "\n",
        "    Args:\n",
        "        img_name: file name of the image inside opts.img_path.\n",
        "\n",
        "    Returns:\n",
        "        A 3-channel tensor resized with img_size and shifted by the per-channel\n",
        "        means below (std is 1, so values are shifted but not rescaled).\n",
        "    '''\n",
        "    resize = transforms.Compose([\n",
        "        transforms.Resize(img_size),\n",
        "        transforms.ToTensor()\n",
        "    ])\n",
        "    normalize = transforms.Normalize(mean=[0.48501961, 0.45795686, 0.40760392], std=[1, 1, 1])\n",
        "    # Force RGB so that grayscale, RGBA, palette or CMYK files never reach\n",
        "    # Normalize with a channel count other than the three means above.\n",
        "    with Image.open(os.path.join(opts.img_path, img_name)) as img_pil:\n",
        "        img = resize(img_pil.convert('RGB'))\n",
        "    return normalize(img)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9jN_GRDsrsdN",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Set target age used for every image in this cell\n",
        "target_age = 28\n",
        "\n",
        "# Collect the test images in a stable (alphabetical) order\n",
        "img_list = sorted(os.listdir(opts.img_path))\n",
        "\n",
        "# Inference only, so gradient tracking is disabled\n",
        "with torch.no_grad():\n",
        "    for img_name in img_list:\n",
        "        # Case-insensitive extension check that also accepts .jpeg files\n",
        "        if not img_name.lower().endswith(('.png', '.jpg', '.jpeg')):\n",
        "            print('File ignored: ' + img_name)\n",
        "            continue\n",
        "        # Add a leading batch dimension and move the image to the model device\n",
        "        image_A = preprocess(img_name).unsqueeze(0).to(device)\n",
        "\n",
        "        age_modif = torch.tensor(target_age).unsqueeze(0).to(device)\n",
        "        image_A_modif = trainer.test_eval(image_A, age_modif, target_age=target_age, hist_trans=True)\n",
        "\n",
        "        # splitext only strips the final extension, so a name such as a.b.jpg keeps\n",
        "        # its full stem instead of being cut at the first dot.\n",
        "        stem = os.path.splitext(img_name)[0]\n",
        "        out_file = os.path.join(opts.out_path, stem + '_age_' + str(target_age) + '.jpg')\n",
        "        utils.save_image(clip_img(image_A_modif), out_file)\n",
        "\n",
        "        # Plot manipulated image, read back from the file that was just written\n",
        "        with Image.open(out_file) as saved_img:\n",
        "            img_out = np.array(saved_img)\n",
        "        plt.axis('off')\n",
        "        plt.imshow(img_out)\n",
        "        plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}